import torch.nn as nn
import torch
from torch.distributions.normal import Normal
from functools import partial

from model.net.transformer import Block as TransformerBlock, GPTConfig
from model.net.mlp import MLP
from model.net.actor import Actor
from model.net.critic import Critic
from model.net.custom import DiscreteEmbeddingLayer
from tools.init_weights import init_weights


def truncated_normal(mean_raw, std, lower, upper, eps=1e-6):
    """Sample from a normal distribution truncated to ``[lower, upper]``.

    The unconstrained ``mean_raw`` is first squashed into the interval with a
    sigmoid, so the location of the truncated normal always lies inside the
    support. One sample per element is then drawn by inverse-CDF sampling.

    Args:
        mean_raw: Unconstrained location tensor (e.g. raw actor output).
        std: Standard deviation tensor, applied element-wise with
            ``mean_raw``; clamped to at least ``eps``.
        lower: Lower truncation bound.
        upper: Upper truncation bound.
        eps: Numerical floor for ``std``, for the normaliser ``Z`` and for the
            uniform draw fed to the inverse CDF.

    Returns:
        Tuple ``(sample, log_prob, entropy, mean)``, all element-wise (nothing
        is reduced over any dimension):

        * ``sample`` - draw from the truncated normal.
        * ``log_prob`` - log-density of ``sample`` under the truncated normal.
        * ``entropy`` - closed-form differential entropy of the truncated
          normal, ``0.5*log(2*pi*e*std^2) + log Z
          + (alpha*phi(alpha) - beta*phi(beta)) / (2Z)``.
        * ``mean`` - the squashed location parameter.
    """
    mean = lower + (upper - lower) * torch.sigmoid(mean_raw)
    std = std.clamp(min=eps)

    # Standardised truncation points (alpha, beta in the usual notation).
    a_norm = (lower - mean) / std
    b_norm = (upper - mean) / std

    normal_dist = Normal(torch.zeros_like(mean), torch.ones_like(std))
    cdf_a = normal_dist.cdf(a_norm)
    cdf_b = normal_dist.cdf(b_norm)
    # Probability mass of the untruncated normal inside [lower, upper].
    Z = (cdf_b - cdf_a).clamp(min=eps)

    # Inverse-CDF sampling: draw u uniformly in [cdf_a, cdf_b]. Keep u away
    # from 0 and 1 so icdf stays finite.
    u = torch.rand_like(mean) * (cdf_b - cdf_a) + cdf_a
    u = u.clamp(min=eps, max=1.0 - eps)
    sample = mean + std * normal_dist.icdf(u)

    log_prob = normal_dist.log_prob((sample - mean) / std) - torch.log(std) - torch.log(Z)

    # Truncation can only reduce entropy. Log Z <= 0 enters with a plus sign,
    # and the boundary term accounts for the density cut at alpha / beta.
    pdf_a = torch.exp(normal_dist.log_prob(a_norm))
    pdf_b = torch.exp(normal_dist.log_prob(b_norm))
    entropy = (
        0.5 * torch.log(2 * torch.pi * std ** 2)
        + 0.5
        + torch.log(Z)
        + (a_norm * pdf_a - b_norm * pdf_b) / (2 * Z)
    )

    return sample, log_prob, entropy, mean


class PPOContinuous(nn.Module):
    """PPO agent with separate actor/critic networks and a truncated-normal policy.

    ``method`` selects the backbone built by :meth:`create_modules`
    (``"mlp"`` or ``"transformer"``). The actor produces an unconstrained mean
    per output unit. ``truncated_normal`` squashes that mean into
    ``[lower_bound, upper_bound]`` with a sigmoid and samples from a normal
    truncated to the same interval. The log standard deviation is a learned,
    state-independent parameter.
    """

    # Final clip applied to returned actions. The upper value is looser than
    # ``upper_bound``; the values are kept as in the original code.
    ACTION_CLIP_MIN = -5.0
    ACTION_CLIP_MAX = 10.0

    def __init__(self, envs, quantile_num, window_size,
                 confidence=0.95, embed_dim=256, depth=2, middle_dim=256,
                 num_head=8, group_size=3, flash=True,
                 method="transformer", device="cpu"):
        """Build the actor and critic networks.

        Args:
            envs: Environment object (presumably vectorised). Only
                ``observation_space.shape[1]`` and ``action_space.shape[-1]``
                are read.
            quantile_num: Number of actor outputs per action dimension.
            window_size: Observation sequence length (``seq_len``).
            confidence: Stored on the instance; not read within this class.
            embed_dim, depth, middle_dim: Backbone sizes.
            num_head, group_size, flash: Transformer block settings (unused
                when ``method == "mlp"``).
            method: ``"mlp"`` or ``"transformer"``.
            device: Passed to the actor/critic; also used to place the policy
                standard deviation.
        """
        super().__init__()

        self.device = device
        self.confidence = confidence
        self.quantile_num = quantile_num
        self.method = method

        # The last column along observation_space.shape[1] is excluded from the
        # feature width. NOTE(review): presumably it is consumed separately
        # by Actor/Critic - confirm.
        feature_dim = envs.observation_space.shape[1] - 1
        output_dim = envs.action_space.shape[-1] * self.quantile_num
        seq_len = window_size

        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        module_kwargs = dict(
            feature_dim=feature_dim,
            embed_dim=embed_dim,
            seq_len=seq_len,
            middle_dim=middle_dim,
            output_dim=output_dim,
            depth=depth,
            norm_layer=norm_layer,
            # Forward the transformer settings explicitly. Otherwise
            # create_modules falls back to its own defaults and the
            # constructor arguments have no effect.
            num_head=num_head,
            group_size=group_size,
            flash=flash,
        )
        # Two calls give the actor and critic independent (non-shared) stacks.
        actor_modules = self.create_modules(**module_kwargs)
        critic_modules = self.create_modules(**module_kwargs)

        self.actor_mean = Actor(encoder_layer=actor_modules['encoder_layer'],
                                middle_layer=actor_modules['middle_layer'],
                                decoder_layer=actor_modules['decoder_layer'],
                                norm=actor_modules['norm'],
                                blocks=actor_modules['blocks'],
                                method=self.method,
                                device=self.device,
                                seq_len=seq_len)

        # State-independent log standard deviation, one per output unit.
        self.actor_logstd = nn.Parameter(torch.zeros(1, output_dim))

        self.critic = Critic(encoder_layer=critic_modules['encoder_layer'],
                             middle_layer=critic_modules['middle_layer'],
                             decoder_layer=critic_modules['decoder_layer'],
                             norm=critic_modules['norm'],
                             blocks=critic_modules['blocks'],
                             action_dim=output_dim,
                             embed_dim=embed_dim,
                             norm_layer=norm_layer,
                             method=self.method,
                             device=self.device,
                             seq_len=seq_len,
                             )

        # Support of the truncated-normal policy.
        self.lower_bound = -5.0
        self.upper_bound = 5.0

    def create_modules(self, feature_dim, embed_dim, seq_len, depth, middle_dim, output_dim,
                       norm_layer=partial(nn.LayerNorm, eps=1e-6), num_head=8, group_size=3, flash=True):
        """Build one encoder/blocks/norm/middle/decoder stack for ``self.method``.

        Args:
            feature_dim: Input feature width of the stem ``nn.Linear``.
            embed_dim: Hidden width.
            seq_len: Sequence length. It sizes the MLP sequence-reduction layer
                and the transformer ``block_size`` / index embedding.
            depth: Number of MLP or transformer blocks.
            middle_dim: Width of the MLP sequence-reduction layer.
            output_dim: Output width of the decoder ``nn.Linear``.
            norm_layer: Factory called with ``embed_dim`` to build the norm.
            num_head, group_size, flash: Forwarded to ``GPTConfig``
                (transformer only).

        Returns:
            ``nn.ModuleDict`` with keys ``encoder_layer``, ``blocks``, ``norm``,
            ``middle_layer`` and ``decoder_layer``, with ``init_weights``
            applied.

        Raises:
            ValueError: If ``self.method`` is neither ``"mlp"`` nor
                ``"transformer"``.
        """
        if self.method not in ("mlp", "transformer"):
            raise ValueError(
                f"Unknown method {self.method!r}; expected 'mlp' or 'transformer'"
            )

        stem_layer = nn.Linear(feature_dim, embed_dim, bias=True)

        if self.method == "mlp":
            encoder_layer = nn.ModuleDict(
                dict(stem_layer=stem_layer,)
            )

            blocks = nn.Sequential(*[
                MLP(in_features=embed_dim,
                    hidden_features=embed_dim,
                    act_layer=nn.Tanh,
                    out_features=embed_dim)
                for _ in range(depth)
            ])

            norm = norm_layer(embed_dim)

            # NOTE(review): presumably Actor/Critic apply reduce_seq_layer
            # along the sequence axis and then flatten before
            # reduce_embed_layer - confirm.
            middle_layer = nn.ModuleDict(
                dict(
                    reduce_seq_layer=nn.Linear(seq_len, middle_dim),
                    reduce_embed_layer=nn.Linear(middle_dim * embed_dim, embed_dim),
                )
            )

            decoder_layer = nn.Linear(embed_dim, output_dim, bias=True)

        else:  # "transformer"
            indices_embedding_layer = DiscreteEmbeddingLayer(
                num_max_tokens=seq_len,
                embed_dim=embed_dim,
            )

            encoder_layer = nn.ModuleDict(
                dict(
                    indices=indices_embedding_layer,
                    stem_layer=stem_layer,
                )
            )

            transformer_config = GPTConfig(
                embed_dim=embed_dim,
                num_head=num_head,
                dropout=0.0,
                bias=True,
                block_size=seq_len,
                group_size=group_size,
                flash=flash
            )

            blocks = nn.Sequential(*[
                TransformerBlock(transformer_config) for _ in range(depth)
            ])

            norm = norm_layer(embed_dim)

            middle_layer = nn.Identity()

            decoder_layer = nn.Linear(
                embed_dim,
                output_dim,
                bias=True,
            )

        modules = nn.ModuleDict({
            "encoder_layer": encoder_layer,
            "blocks": blocks,
            "norm": norm,
            "middle_layer": middle_layer,
            "decoder_layer": decoder_layer,
        })

        # Apply weight initialization
        modules.apply(init_weights)

        return modules

    @staticmethod
    def _truncated_log_prob(value, mean, std, lower, upper, eps=1e-6):
        """Log-density of ``value`` under N(mean, std) truncated to [lower, upper].

        ``mean`` is the already-squashed location. ``std`` and the normaliser
        ``Z`` are clamped to ``eps``, using the same normalisation as the
        ``log_prob`` returned by ``truncated_normal``.
        """
        std = std.clamp(min=eps)
        standard = Normal(torch.zeros_like(mean), torch.ones_like(std))
        z = (standard.cdf((upper - mean) / std)
             - standard.cdf((lower - mean) / std)).clamp(min=eps)
        return standard.log_prob((value - mean) / std) - torch.log(std) - torch.log(z)

    def get_value(self, x):
        """Return the critic's output for observation ``x``."""
        return self.critic(x)

    def get_action_and_value(self, x, action=None):
        """Sample (or score) an action and evaluate the critic.

        Args:
            x: Observation batch fed to both the actor and the critic.
            action: Optional previously taken action. If given, it is returned
                unchanged and ``log_prob`` is evaluated at it. PPO needs this
                to recompute probability ratios for stored actions. If
                ``None``, a fresh action is sampled.

        Returns:
            ``(action, log_prob, entropy, value)``. ``log_prob`` and
            ``entropy`` are element-wise (not reduced over output units).
        """
        action_mean = self.actor_mean(x)
        action_logstd = self.actor_logstd.expand_as(action_mean)
        action_std = torch.exp(action_logstd).to(self.device)

        sample, log_prob, entropy, squashed_mean = truncated_normal(action_mean,
                                                                    action_std,
                                                                    self.lower_bound,
                                                                    self.upper_bound)
        if action is None:
            action = torch.clamp(sample, min=self.ACTION_CLIP_MIN, max=self.ACTION_CLIP_MAX)
        else:
            log_prob = self._truncated_log_prob(action, squashed_mean, action_std,
                                                self.lower_bound, self.upper_bound)
        return action, log_prob, entropy, self.critic(x)

    def get_action(self, x):
        """Return the deterministic action for observation ``x``.

        This is the location of the policy distribution. The raw actor
        output is sigmoid-squashed into ``[lower_bound, upper_bound]``, the
        same way ``truncated_normal`` does, and then clipped.
        """
        mean_raw = self.actor_mean(x)
        mean = self.lower_bound + (self.upper_bound - self.lower_bound) * torch.sigmoid(mean_raw)
        return torch.clamp(mean, min=self.ACTION_CLIP_MIN, max=self.ACTION_CLIP_MAX)